using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System;

namespace NetCorePal.Extensions.CodeAnalysis.SourceGenerators;

/// <summary>
/// Incremental source generator that emits one assembly-level <c>EntityMetadataAttribute</c> for every
/// public class whose symbol satisfies <c>IsEntity()</c>. Each attribute records the entity's display name,
/// whether it satisfies <c>IsAggregateRoot()</c>, the other entity types referenced by its properties and
/// fields (directly, as a single-type-argument generic, or as an array element), and the names of the
/// methods and constructors it declares.
/// </summary>
/// <remarks>
/// Members are read from every declaration of a partial class, not only the first one encountered.
/// </remarks>
[Generator]
public class EntityMetadataGenerator : IIncrementalGenerator
{
    /// <inheritdoc />
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        // Every class declaration is a candidate; the semantic entity check happens in the output step.
        var classDeclarations = context.SyntaxProvider
            .CreateSyntaxProvider(
                predicate: static (node, _) => node is ClassDeclarationSyntax,
                transform: static (ctx, _) => (ClassDeclarationSyntax)ctx.Node);

        var compilationAndClasses = context.CompilationProvider.Combine(classDeclarations.Collect());

        context.RegisterSourceOutput(compilationAndClasses, static (spc, source) =>
        {
            var (compilation, classes) = source;

            // One SemanticModel per syntax tree, shared by every class declared in that tree.
            var models = new Dictionary<SyntaxTree, SemanticModel>();
            var entities = CollectPublicEntities(compilation, classes, models);

            var sb = new StringBuilder();
            sb.AppendLine("// <auto-generated/>");
            sb.AppendLine("using System;");
            sb.AppendLine("using NetCorePal.Extensions.CodeAnalysis.Attributes;");
            foreach (var entity in entities)
            {
                spc.CancellationToken.ThrowIfCancellationRequested();
                var declarations = GetClassDeclarations(entity, compilation, models);
                sb.AppendLine(BuildAttributeLine(entity, declarations));
            }

            spc.AddSource("EntityMetadata.g.cs", sb.ToString());
        });
    }

    /// <summary>
    /// Returns the distinct public entity types declared by <paramref name="classes"/>, in first-seen order.
    /// </summary>
    /// <param name="compilation">The compilation the declarations belong to.</param>
    /// <param name="classes">All class declarations found by the syntax provider.</param>
    /// <param name="models">Per-tree semantic model cache, filled on demand.</param>
    private static List<INamedTypeSymbol> CollectPublicEntities(
        Compilation compilation,
        IEnumerable<ClassDeclarationSyntax> classes,
        Dictionary<SyntaxTree, SemanticModel> models)
    {
        // Partial classes yield several declarations for one symbol; keep each symbol once.
        var seen = new HashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default);
        var entities = new List<INamedTypeSymbol>();
        foreach (var classDecl in classes)
        {
            var model = GetSemanticModel(compilation, classDecl.SyntaxTree, models);
            // DeclaredAccessibility already reflects a 'public' modifier written on any partial declaration.
            if (model.GetDeclaredSymbol(classDecl) is INamedTypeSymbol symbol
                && symbol.DeclaredAccessibility == Accessibility.Public
                && symbol.IsEntity()
                && seen.Add(symbol))
            {
                entities.Add(symbol);
            }
        }

        return entities;
    }

    /// <summary>
    /// Returns every class declaration (all partial parts) of <paramref name="entity"/> with its semantic model.
    /// </summary>
    private static List<(ClassDeclarationSyntax Decl, SemanticModel Model)> GetClassDeclarations(
        INamedTypeSymbol entity,
        Compilation compilation,
        Dictionary<SyntaxTree, SemanticModel> models)
    {
        var declarations = new List<(ClassDeclarationSyntax Decl, SemanticModel Model)>();
        foreach (var reference in entity.DeclaringSyntaxReferences)
        {
            if (reference.GetSyntax() is ClassDeclarationSyntax decl)
            {
                declarations.Add((decl, GetSemanticModel(compilation, decl.SyntaxTree, models)));
            }
        }

        return declarations;
    }

    /// <summary>
    /// Returns the cached semantic model for <paramref name="tree"/>, creating it on first use.
    /// </summary>
    private static SemanticModel GetSemanticModel(
        Compilation compilation,
        SyntaxTree tree,
        Dictionary<SyntaxTree, SemanticModel> cache)
    {
        if (!cache.TryGetValue(tree, out var model))
        {
            model = compilation.GetSemanticModel(tree);
            cache[tree] = model;
        }

        return model;
    }

    /// <summary>
    /// Builds the <c>[assembly: EntityMetadataAttribute(...)]</c> line for one entity.
    /// </summary>
    /// <param name="entity">The entity symbol being described.</param>
    /// <param name="declarations">All class declarations of <paramref name="entity"/> with their models.</param>
    /// <returns>A single line of C# source (without trailing newline).</returns>
    private static string BuildAttributeLine(
        INamedTypeSymbol entity,
        List<(ClassDeclarationSyntax Decl, SemanticModel Model)> declarations)
    {
        var subEntities = new List<string>();

        // Properties of any accessibility.
        foreach (var (decl, model) in declarations)
        {
            foreach (var prop in decl.Members.OfType<PropertyDeclarationSyntax>())
            {
                if (model.GetTypeInfo(prop.Type).Type is ITypeSymbol propType)
                {
                    AddSubEntity(propType, entity, subEntities);
                }
            }
        }

        // Fields of any accessibility.
        foreach (var (decl, model) in declarations)
        {
            foreach (var field in decl.Members.OfType<FieldDeclarationSyntax>())
            {
                if (model.GetTypeInfo(field.Declaration.Type).Type is ITypeSymbol fieldType)
                {
                    AddSubEntity(fieldType, entity, subEntities);
                }
            }
        }

        // Methods of any accessibility (static or instance); constructors are reported as .ctor / .cctor.
        var methodNames = new List<string>();
        var hasInstanceConstructor = false;
        var hasStaticConstructor = false;
        foreach (var (decl, _) in declarations)
        {
            // ValueText drops the '@' of verbatim identifiers (e.g. @event -> event), matching the real method name.
            methodNames.AddRange(decl.Members.OfType<MethodDeclarationSyntax>().Select(m => m.Identifier.ValueText));
            foreach (var ctor in decl.Members.OfType<ConstructorDeclarationSyntax>())
            {
                if (ctor.Modifiers.Any(m => m.IsKind(SyntaxKind.StaticKeyword)))
                {
                    hasStaticConstructor = true;
                }
                else
                {
                    hasInstanceConstructor = true;
                }
            }
        }

        if (hasInstanceConstructor)
        {
            methodNames.Add(".ctor");
        }

        if (hasStaticConstructor)
        {
            methodNames.Add(".cctor");
        }

        var typeName = entity.ToDisplayString();
        // Emit the C# keywords directly instead of the culture-sensitive bool.ToString().ToLower().
        var aggregateRootLiteral = entity.IsAggregateRoot() ? "true" : "false";
        var subEntitiesLiteral = string.Join(", ", subEntities.Distinct().Select(e => $"\"{e}\""));
        var methodNamesLiteral = string.Join(", ", methodNames.Distinct().Select(m => $"\"{m}\""));
        return $"[assembly: EntityMetadataAttribute(\"{typeName}\", {aggregateRootLiteral}, new string[] {{ {subEntitiesLiteral} }}, new string[] {{ {methodNamesLiteral} }})]";
    }

    /// <summary>
    /// Appends to <paramref name="subEntities"/> the entity referenced by a member of type
    /// <paramref name="memberType"/>: the type itself, the sole type argument of a generic type, or the
    /// element type of an array. References to <paramref name="owner"/> itself are ignored.
    /// </summary>
    private static void AddSubEntity(ITypeSymbol memberType, INamedTypeSymbol owner, List<string> subEntities)
    {
        if (memberType is INamedTypeSymbol named)
        {
            if (IsOtherEntity(named, owner))
            {
                subEntities.Add(named.ToDisplayString());
            }
            else if (named.IsGenericType
                     && named.TypeArguments.Length == 1
                     && named.TypeArguments[0] is INamedTypeSymbol element
                     && IsOtherEntity(element, owner))
            {
                // Single-argument generics such as List<T>, IReadOnlyCollection<T>.
                subEntities.Add(element.ToDisplayString());
            }
        }
        else if (memberType is IArrayTypeSymbol array
                 && array.ElementType is INamedTypeSymbol arrayElement
                 && IsOtherEntity(arrayElement, owner))
        {
            subEntities.Add(arrayElement.ToDisplayString());
        }
    }

    /// <summary>
    /// True when <paramref name="candidate"/> is an entity type other than <paramref name="owner"/>.
    /// </summary>
    private static bool IsOtherEntity(INamedTypeSymbol candidate, INamedTypeSymbol owner)
        => candidate.IsEntity() && !candidate.Equals(owner, SymbolEqualityComparer.Default);
}